import torch
import torch.nn as nn


class SoftMax2dArea(nn.Module):
    """Softmax over every non-batch element of each sample.

    Unlike ``nn.Softmax2d`` (which normalises across channels separately at
    each spatial location), this module flattens all dimensions after the
    batch dimension and applies a single softmax. For each sample, the
    outputs over the whole ``(C, H, W)`` volume therefore sum to 1.

    The intended input is ``(B, C, H, W)``, but any tensor with at least two
    dimensions is accepted. The output always has the same shape as the input.
    """

    def __init__(self):
        super().__init__()
        # Applied to the (B, N) flattened view, so dim=1 spans all N elements
        # belonging to one sample.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return the per-sample area softmax of ``x``.

        Args:
            x: Tensor of shape ``(B, ...)`` with at least two dimensions,
                typically ``(B, C, H, W)``.

        Returns:
            Tensor with the same shape as ``x``. For each sample (index along
            dim 0), the entries are the softmax over all of that sample's
            elements.

        Raises:
            ValueError: If ``x`` has fewer than two dimensions.
        """
        if x.dim() < 2:
            raise ValueError(
                "SoftMax2dArea expects a tensor with a batch dimension and at "
                f"least one more dimension, got shape {tuple(x.shape)}"
            )
        # flatten(start_dim=1) rather than reshape(b, -1): reshape cannot
        # infer the -1 size when the batch is empty (b == 0).
        flat = torch.flatten(x, start_dim=1)
        return self.softmax(flat).reshape(x.shape)

if __name__ == '__main__':
    # Small smoke demo: nine single-channel 3x3 maps, each of which is
    # softmax-normalised over its entire area.
    area_softmax = SoftMax2dArea()
    demo_input = torch.randn((9, 1, 3, 3))
    demo_output = area_softmax(demo_input)
    print(demo_output)
